GRU
=================

将多层门控循环单元 (GRU) RNN 应用于输入序列。

GRU 网络模型中有两个门:更新门和重置门。将两个连续的时间节点表示为 :math:`t - 1` 和 :math:`t`。给定一个在时刻 :math:`t` 的输入 :math:`x_t`,一个隐藏状态 :math:`h_{t-1}`,在时刻 :math:`t` 的更新门和重置门使用门控制机制计算。更新门 :math:`z_t` 用于控制前一时刻的状态信息被带入到当前状态中的程度,重置门 :math:`r_t` 控制前一状态有多少信息被写入到当前候选集 :math:`n_t` 上。

对于输入序列中的每个元素,每一层计算以下函数:

.. math::
   :nowrap:

   \begin{align*}
   r_t &= \sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{(t-1)} + b_{hr}) \\
   z_t &= \sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{(t-1)} + b_{hz}) \\
   n_t &= \tanh(W_{in}x_t + b_{in} + r_t \odot (W_{hn}h_{(t-1)} + b_{hn})) \\
   h_t &= (1-z_t) \odot n_t + z_t \odot h_{(t-1)}
   \end{align*}

其中 :math:`\sigma` 是 sigmoid 激活函数,:math:`\odot` 是 Hadamard 积(逐元素乘积)。:math:`W, b` 是公式中输出和输入之间的可学习权重。例如,:math:`W_{ir}, b_{ir}` 是用于将输入 :math:`x_t` 转换为 :math:`r_t` 的权重和偏置。

注意,本算子中候选门 :math:`n_t` 的计算与原始论文和 MindSpore 框架略有不同。在原始实现中,:math:`r_t` 和上一隐藏状态 :math:`h_{(t-1)}` 之间的 Hadamard 积 (:math:`\odot`) 在与权重矩阵 :math:`W` 相乘和加上偏置之前进行:

.. math::

   n_t = \tanh(W_{in}x_t + b_{in} + W_{hn}(r_t \odot h_{(t-1)}) + b_{hn})

本算子采用 PyTorch 实现方式,是在 :math:`W_{hn}h_{(t-1)}` 之后完成的:

.. math::

   n_t = \tanh(W_{in}x_t + b_{in} + r_t \odot (W_{hn}h_{(t-1)} + b_{hn}))

输入:
    - **input** - 输入数据的地址。
    - **weight_g** - 可学习的输入-隐藏权重的地址。
    - **weight_r** - 可学习的隐藏-隐藏权重的地址。
    - **input_bias** - 可学习的输入-隐藏偏置的地址。
    - **state_bias** - 可学习的隐藏-隐藏偏置的地址。
    - **hidden_state** - 初始隐藏状态的地址。
    - **buffer** - 用于存储中间计算结果。
    - **gru_param** - 算子计算所需参数的结构体。其各成员见下述。
    - **core_mask** - 核掩码。

**GruParameter定义:**

.. code-block:: c
   :linenos:

   typedef struct GruParameter {
     int input_size_;       // 输入input中预期特征的数量
     int hidden_size_;      // 隐藏状态h中的特征数量
     int seq_len_;          // 输入batch中每个序列的长度
     int batch_;            // 总批次数
     int output_step_;      // 每次循环中output步长
     int bidirectional_;    // 是否为双向GRU
     int input_row_align_;  // 输入行对齐值
     int input_col_align_;  // 输入列对齐值
     int state_row_align_;  // 隐藏状态行对齐值
     int state_col_align_;  // 隐藏状态列对齐值
     int check_seq_len_;    // 进行计算的序列长度
   } GruParameter;

输出:
    - **output** - 输出地址。
    - **hidden_state** - 最终的隐藏状态。

支持平台:
    ``FT78NE`` ``MT7004``
.. note::
   - FT78NE 支持int8, fp32
   - MT7004 支持fp16, fp32

**共享存储版本:**

.. c:function:: void i8_Gru_s(int8_t *output, int8_t *input, int8_t *weight_g, int8_t *weight_r, int8_t *input_bias, int8_t *state_bias, int8_t *hidden_state, int8_t *buffer[4], GruParameter *gru_param, int core_mask);

.. c:function:: void hp_Gru_s(half *output, half *input, half *weight_g, half *weight_r, half *input_bias, half *state_bias, half *hidden_state, half *buffer[4], GruParameter *gru_param, int core_mask);

.. c:function:: void fp_Gru_s(float *output, float *input, float *weight_g, float *weight_r, float *input_bias, float *state_bias, float *hidden_state, float *buffer[4], GruParameter *gru_param, int core_mask);

**C调用示例:**

.. code-block:: c
   :linenos:
   :emphasize-lines: 43

   void TestGruSMCFp32(int check_seq_len, int seq_len, int batch_size, int input_size, int bidirectional, int hidden_size, int core_mask)
   {
     int core_id = get_core_id();
     int logic_core_id = GetLogicCoreId(core_mask, core_id);
     int core_num = GetCoreNum(core_mask);
     float *output = (void*)0x88000000;
     float *input = (void*)0x88100000;
     float *weight_g = (void*)0x88200000;
     float *weight_r = (void*)0x88300000;
     float *input_bias = (void*)0x88400000;
     float *state_bias = (void*)0x88500000;
     float *hidden_state = (void*)0x88600000;
     float** buffer = (float**)0x88700000;
     float *output_hidden_state = (void*)0x88800000;
     float *check_output_hidden_state = (void*)0x88E00000;  // 用于结果校验的初始隐藏状态副本
     GruParameter* param = (GruParameter*)0x88900000;
     int hidden_state_batch = 1;
     int num_directions = 1;
     if (bidirectional) {
       hidden_state_batch = hidden_state_batch * 2;
       num_directions = num_directions * 2;
     }
     int input_col_align = hidden_size;
     int state_col_align = hidden_size;
     if (logic_core_id == 0) {
       memcpy(output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
       memcpy(check_output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
       buffer[0] = (void*)0x88A00000;
       buffer[1] = (void*)0x88B00000;
       buffer[2] = (void*)0x88C00000;
       buffer[3] = (void*)0x88D00000;
       param->batch_ = batch_size;
       param->bidirectional_ = bidirectional;
       param->hidden_size_ = hidden_size;
       param->input_col_align_ = input_col_align;
       param->input_size_ = input_size;
       param->output_step_ = batch_size * hidden_size * num_directions;
       param->seq_len_ = seq_len;
       param->state_col_align_ = state_col_align;
       param->check_seq_len_ = check_seq_len;
     }
     sys_bar(0, core_num);  // 初始化参数完成后进行同步
     fp_Gru_s(output, input, weight_g, weight_r, input_bias, state_bias, output_hidden_state, buffer, param, core_mask);
   }

   void main()
   {
     int check_seq_len = 2;
     int seq_len = 2;
     int batch_size = 2;
     int input_size = 2;
     int bidirectional = 0;
     int hidden_size = 2;
     int core_mask = 0b1111;
     TestGruSMCFp32(check_seq_len, seq_len, batch_size, input_size, bidirectional, hidden_size, core_mask);
   }

**私有存储版本:**

.. c:function:: void i8_Gru_p(int8_t *output, int8_t *input, int8_t *weight_g, int8_t *weight_r, int8_t *input_bias, int8_t *state_bias, int8_t *hidden_state, int8_t *buffer[4], GruParameter *gru_param, int core_mask);

.. c:function:: void hp_Gru_p(half *output, half *input, half *weight_g, half *weight_r, half *input_bias, half *state_bias, half *hidden_state, half *buffer[4], GruParameter *gru_param, int core_mask);

.. c:function:: void fp_Gru_p(float *output, float *input, float *weight_g, float *weight_r, float *input_bias, float *state_bias, float *hidden_state, float *buffer[4], GruParameter *gru_param, int core_mask);

**C调用示例:**
.. code-block:: c
   :linenos:
   :emphasize-lines: 37

   void TestGruL2Fp32(int check_seq_len, int seq_len, int batch_size, int input_size, int bidirectional, int hidden_size, int core_mask)
   {
     float *output = (void*)0x10000000;  // 私有存储版本地址设置在AM内
     float *input = (void*)0x10004000;
     float *weight_g = (void*)0x10008000;
     float *weight_r = (void*)0x1000C000;
     float *input_bias = (void*)0x10010000;
     float *state_bias = (void*)0x10014000;
     float *hidden_state = (void*)0x10018000;
     float** buffer = (float**)0x1001C000;
     float *output_hidden_state = (void*)0x10020000;
     float *check_output_hidden_state = (void*)0x10028000;  // 用于结果校验的初始隐藏状态副本
     GruParameter* param = (GruParameter*)0x10024000;
     int hidden_state_batch = 1;
     int num_directions = 1;
     if (bidirectional) {
       hidden_state_batch = hidden_state_batch * 2;
       num_directions = num_directions * 2;
     }
     int input_col_align = hidden_size;
     int state_col_align = hidden_size;
     memcpy(output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
     memcpy(check_output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
     buffer[0] = (void*)0x10030000;
     buffer[1] = (void*)0x10034000;
     buffer[2] = (void*)0x10038000;
     buffer[3] = (void*)0x1003C000;
     param->batch_ = batch_size;
     param->bidirectional_ = bidirectional;
     param->hidden_size_ = hidden_size;
     param->input_col_align_ = input_col_align;
     param->input_size_ = input_size;
     param->output_step_ = batch_size * hidden_size * num_directions;
     param->seq_len_ = seq_len;
     param->state_col_align_ = state_col_align;
     param->check_seq_len_ = check_seq_len;
     fp_Gru_p(output, input, weight_g, weight_r, input_bias, state_bias, output_hidden_state, buffer, param, core_mask);
   }

   void main()
   {
     int check_seq_len = 2;
     int seq_len = 2;
     int batch_size = 2;
     int input_size = 2;
     int bidirectional = 0;
     int hidden_size = 2;
     int core_mask = 0b0001;  // 私有存储版本只能设置为一个核心启动
     TestGruL2Fp32(check_seq_len, seq_len, batch_size, input_size, bidirectional, hidden_size, core_mask);
   }